iT邦幫忙

2023 iThome 鐵人賽

DAY 24
0
AI & Data

深度學習概念和應用(PyTorch)系列 第 24

DAY24 mnist實作 – 模型訓練、結果輸出

  • 分享至 

  • xImage
  •  

定義模型後,可以開始

model = CNN() 
epochs = 100 
batch_size = 300
lr = 1e-3

建立模組、定義訓練次數、批次大小和學習率

  xt = train_data.data[batch_CNN].detach() 
  yt = train_data.train_labels[batch_CNN].detach() 
  pred = model(xt)
  pred_labels = torch.argmax(pred,dim=1) 

產生批次後,y軸取得標籤資訊,在找出預測的最大值作為結果

  acc_ = 100.0 * (pred_labels == yt).sum() / batch_size 
  print('Current training accuracy: ', acc_.item()) 

計算準確率並預測

plt.figure(figsize=(10,7))
plt.xlabel("Training Epochs", fontsize=12)
plt.ylabel("Training accuracy", fontsize=12)
plt.plot(acc_CNN)

https://ithelp.ithome.com.tw/upload/images/20231009/20163187wI8RsQvnNj.png
可以畫出每次epoch的準確度,從張圖我們可以知道訓練週期越到後面,模型的準確率也越高。

for i in range (10):
  x = test_data.data[test_id[i]]     
  plt.imshow(x)
  print('\n預測數字是:', pred_ind[i])  

輸出測試資料,可以看到每張圖型和預測結果
https://ithelp.ithome.com.tw/upload/images/20231009/20163187urEH8CuS0B.png


上一篇
DAY23 mnist實作 - 數據集處理
下一篇
DAY25 CycleGAN循環式生成對抗網路1
系列文
深度學習概念和應用(PyTorch)30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言